//! [`Extension`] support for inserting arbitrary values into contexts and extracting them back out.

use motore::{layer::Layer, service::Service};
use volo::context::Context;

/// Carries a value of type `T` into or out of a request context.
///
/// Applied as a [`Layer`], it clones the wrapped value into the context's extensions on every
/// call. Used as an extractor in a handler (with the `server` feature), it clones the stored `T`
/// back out of the context. If no `T` is present, the request is rejected with a
/// `500 Internal Server Error`.
///
/// # Examples
///
/// ```
/// use volo_http::{
///     server::route::{get, Router},
///     utils::Extension,
/// };
///
/// #[derive(Clone)]
/// struct State {
///     foo: String,
/// }
///
/// // Pull `State` back out of the context inside a handler.
/// async fn show_state(Extension(state): Extension<State>) -> String {
///     state.foo
/// }
///
/// let router: Router = Router::new()
///     .route("/", get(show_state))
///     // Install `Extension` as a `Layer` so handlers can see `State`.
///     .layer(Extension(State {
///         foo: String::from("bar"),
///     }));
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);

impl<S, T> Layer<S> for Extension<T>
where
    S: Send + Sync + 'static,
    T: Sync,
{
    type Service = ExtensionService<S, T>;

    fn layer(self, inner: S) -> Self::Service {
        ExtensionService { inner, ext: self.0 }
    }
}

/// The [`Service`] produced when [`Extension`] is applied as a [`Layer`].
///
/// On each call it inserts a clone of `ext` into the context's extensions, then delegates to
/// `inner`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExtensionService<I, T> {
    /// The wrapped service that handles the request afterwards.
    inner: I,
    /// The value cloned into the context on every call.
    ext: T,
}

impl<S, Cx, Req, Resp, E, T> Service<Cx, Req> for ExtensionService<S, T>
where
    S: Service<Cx, Req, Response = Resp, Error = E> + Send + Sync + 'static,
    Req: Send,
    Cx: Context + Send,
    T: Clone + Send + Sync + 'static,
{
    type Response = Resp;
    type Error = E;

    /// Stores a fresh clone of the carried value in `cx`'s extensions, then forwards `req` to
    /// the inner service and returns its result unchanged.
    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
        let value = T::clone(&self.ext);
        cx.extensions_mut().insert(value);
        self.inner.call(cx, req).await
    }
}

#[cfg(feature = "server")]
mod server {
    use std::fmt;

    use http::{request::Parts, StatusCode};
    use volo::context::Context;

    use super::Extension;
    use crate::{
        context::ServerContext,
        response::Response,
        server::{extract::FromContext, IntoResponse},
    };

    /// Extracts a clone of the `T` stored in the context's extensions, for example by the
    /// [`Extension`] layer.
    ///
    /// Fails with [`ExtensionRejection::NotExist`] when the context holds no value of type `T`.
    impl<T> FromContext for Extension<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Rejection = ExtensionRejection;

        async fn from_context(
            cx: &mut ServerContext,
            _parts: &mut Parts,
        ) -> Result<Self, Self::Rejection> {
            cx.extensions()
                .get::<T>()
                .cloned()
                .map(Extension)
                .ok_or(ExtensionRejection::NotExist)
        }
    }

    /// Rejection returned when extracting an [`Extension`] from the context fails.
    #[derive(Debug)]
    pub enum ExtensionRejection {
        /// No value of the requested type was found in the context's extensions.
        NotExist,
    }

    impl fmt::Display for ExtensionRejection {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                Self::NotExist => f.write_str("extension does not exist in the context"),
            }
        }
    }

    impl std::error::Error for ExtensionRejection {}

    impl IntoResponse for ExtensionRejection {
        /// Maps the rejection to a bare `500 Internal Server Error`.
        ///
        /// A missing extension presumably means the corresponding layer was not installed. That
        /// is a server-side misconfiguration, not a client error, hence 5xx rather than 4xx.
        fn into_response(self) -> Response {
            StatusCode::INTERNAL_SERVER_ERROR.into_response()
        }
    }
}
